-- drop function if exists sm_sc.fv_d_aggr_slice_avg_dloss_dindepdt(float[], int[]);
create or replace function sm_sc.fv_d_aggr_slice_avg_dloss_dindepdt
(
  -- i_indepdt_len    int[],
  -- i_depdt          float[],
  i_dloss_ddepdt   float[],
  i_cnt_per_grp    int[]
)
returns float[]
as
$$
declare   
  v_dloss_ddepdt_len_y    int         := array_length(i_dloss_ddepdt, 1);
  v_dloss_ddepdt_len_x    int         := array_length(i_dloss_ddepdt, 2);
  v_dloss_ddepdt_len_x3   int         := array_length(i_dloss_ddepdt, 3);
  v_dloss_ddepdt_len_x4   int         := array_length(i_dloss_ddepdt, 4);
  v_ret                   float[]     ;
  v_cur_y                 int         ;
  v_cur_x                 int         ;
  v_cur_x3                int         ;
  v_cur_x4                int         ;
  -- v_indepdt_len           int[]       := 
  --   (
  --     select 
  --       array_agg(array_length(i_indepdt, a_cur_dim) order by a_cur_dim) 
  --     from generate_series(1, array_ndims(i_indepdt)) tb_a_cur_dim(a_cur_dim)
  --   )
  -- ;
begin
  -- i_cnt_per_grp := 
  --   sm_sc.fv_coalesce
  --   (
  --     v_indepdt_len[ : array_length(v_indepdt_len, 1) - coalesce(array_length(i_cnt_per_grp, 1), 0)] || i_cnt_per_grp
  --   , v_indepdt_len
  --   )
  -- ;
  if current_setting('pg4ml._v_is_debug_check', true) = '1'
  then
    if array_ndims(i_cnt_per_grp) > 1
    then 
      raise exception 'unsupport ndims of i_cnt_per_grp > 1.';
    elsif array_ndims(i_dloss_ddepdt) <> array_length(i_cnt_per_grp, 1) 
    then 
      raise exception 'unmatch between ndims of i_dloss_ddepdt and length of i_cnt_per_grp.';
    -- elsif array_dims(i_depdt) <> array_dims(i_dloss_ddepdt)
    -- then 
    --   raise exception 'unmatch between dims of i_depdt and i_dloss_ddepdt.';
    end if;
  end if;
    
  if i_dloss_ddepdt is null
  then 
    return null;
    
  elsif array_length(i_cnt_per_grp, 1) = 1
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y] *` i_cnt_per_grp);
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      v_ret
        [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
      :=
        array_fill(1.0 :: float / sm_sc.fv_aggr_slice_prod(i_cnt_per_grp), i_cnt_per_grp) 
        *` 
        i_dloss_ddepdt
          [v_cur_y]
      ;        
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 2
  then  
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        v_ret
          [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
          [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
        :=
          array_fill(1.0 :: float / sm_sc.fv_aggr_slice_prod(i_cnt_per_grp), i_cnt_per_grp) 
          *` 
          i_dloss_ddepdt
            [v_cur_y]
            [v_cur_x]
        ;        
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 3
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x, v_dloss_ddepdt_len_x3] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        for v_cur_x3 in 1 .. v_dloss_ddepdt_len_x3
        loop 
          v_ret
            [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
            [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
            [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
          :=
            array_fill(1.0 :: float / sm_sc.fv_aggr_slice_prod(i_cnt_per_grp), i_cnt_per_grp) 
            *` 
            i_dloss_ddepdt
              [v_cur_y]
              [v_cur_x]
              [v_cur_x3]
          ;        
        end loop;
      end loop;
    end loop;
    
  elsif array_length(i_cnt_per_grp, 1) = 4
  then    
    v_ret := array_fill(null :: float, array[v_dloss_ddepdt_len_y, v_dloss_ddepdt_len_x, v_dloss_ddepdt_len_x3, v_dloss_ddepdt_len_x4] *` i_cnt_per_grp);  
    for v_cur_y in 1 .. v_dloss_ddepdt_len_y
    loop 
      for v_cur_x in 1 .. v_dloss_ddepdt_len_x
      loop 
        for v_cur_x3 in 1 .. v_dloss_ddepdt_len_x3
        loop 
          for v_cur_x4 in 1 .. v_dloss_ddepdt_len_x4
          loop 
            v_ret
              [(v_cur_y - 1) * i_cnt_per_grp[1] + 1 : v_cur_y * i_cnt_per_grp[1]]
              [(v_cur_x - 1) * i_cnt_per_grp[2] + 1 : v_cur_x * i_cnt_per_grp[2]]
              [(v_cur_x3 - 1) * i_cnt_per_grp[3] + 1 : v_cur_x3 * i_cnt_per_grp[3]]
              [(v_cur_x4 - 1) * i_cnt_per_grp[4] + 1 : v_cur_x4 * i_cnt_per_grp[4]]
            :=
              array_fill(1.0 :: float / sm_sc.fv_aggr_slice_prod(i_cnt_per_grp), i_cnt_per_grp) 
              *` 
              i_dloss_ddepdt
                [v_cur_y]
                [v_cur_x]
                [v_cur_x3]
                [v_cur_x4]
            ;        
          end loop;
        end loop;
      end loop;
    end loop;
    
  end if;
  
  return v_ret;
end
$$
language plpgsql stable
parallel safe
cost 100;

-- select 
--   sm_sc.fv_d_aggr_slice_avg_dloss_dindepdt
--   (
--     array[2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--   , array[2]
--   ) :: decimal[] ~=` 3

-- select 
--   sm_sc.fv_d_aggr_slice_avg_dloss_dindepdt
--   (
--     array
--     [
--       [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--     , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--     ]
--   , array[3, 2]
--   ) :: decimal[] ~=` 3

-- select 
--   sm_sc.fv_d_aggr_slice_avg_dloss_dindepdt
--   (
--     array
--     [
--       [
--         [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       ]
--     , [
--         [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--       ]
--     ]
--   , array[2, 2, 4]
--   ) :: decimal[] ~=` 3

-- select 
--   sm_sc.fv_d_aggr_slice_avg_dloss_dindepdt
--   (
--     array
--     [
--       [
--         [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       , [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       ]
--     , [
--         [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       , [
--           [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         , [2.3, 5.1, 8.2, 2.56, 3.33, -1.9]
--         ]
--       ]
--     ]
--   , array[2, 2, 4, 3]
--   ) :: decimal[] ~=` 3